{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RizPOl_DRB-_"
      },
      "source": [
        "```\n",
        "- Copyright 2023 DeepMind Technologies Limited\n",
        "- All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0\n",
        "- All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY).  You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode\n",
        "- Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.\n",
        "- This is not an official Google product\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LVKsiNuyfoNZ"
      },
      "source": [
        "# Cap set\n",
        "\n",
        "This notebook contains\n",
        "\n",
        "1. the *skeleton* we used for FunSearch to discover large cap sets,\n",
        "2. the *functions* discovered by FunSearch that construct large cap sets.\n",
        "\n",
        "## Skeleton\n",
        "\n",
        "The commented-out decorators are just a way to indicate the main entry point of the program (`@funsearch.run`) and the function that *FunSearch* should evolve (`@funsearch.evolve`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3zZ0fAe6flO_"
      },
      "outputs": [],
      "source": [
        "\"\"\"Finds large cap sets.\"\"\"\n",
        "import itertools\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "# @funsearch.run\n",
        "def evaluate(n: int) -\u003e int:\n",
        "  \"\"\"Returns the size of an `n`-dimensional cap set.\"\"\"\n",
        "  capset = solve(n)\n",
        "  return len(capset)\n",
        "\n",
        "\n",
        "def solve(n: int) -\u003e np.ndarray:\n",
        "  \"\"\"Returns a large cap set in `n` dimensions.\"\"\"\n",
        "  all_vectors = np.array(list(itertools.product((0, 1, 2), repeat=n)), dtype=np.int32)\n",
        "\n",
        "  # Powers in decreasing order for compatibility with `itertools.product`, so\n",
        "  # that the relationship `i = all_vectors[i] @ powers` holds for all `i`.\n",
        "  powers = 3 ** np.arange(n - 1, -1, -1)\n",
        "\n",
        "  # Precompute all priorities.\n",
        "  priorities = np.array([priority(tuple(vector), n) for vector in all_vectors])\n",
        "\n",
        "  # Build `capset` greedily, using priorities for prioritization.\n",
        "  capset = np.empty(shape=(0, n), dtype=np.int32)\n",
        "  while np.any(priorities != -np.inf):\n",
        "    # Add a vector with maximum priority to `capset`, and set priorities of\n",
        "    # invalidated vectors to `-inf`, so that they never get selected.\n",
        "    max_index = np.argmax(priorities)\n",
        "    vector = all_vectors[None, max_index]  # [1, n]\n",
        "    blocking = np.einsum('cn,n-\u003ec', (- capset - vector) % 3, powers)  # [C]\n",
        "    priorities[blocking] = -np.inf\n",
        "    priorities[max_index] = -np.inf\n",
        "    capset = np.concatenate([capset, vector], axis=0)\n",
        "\n",
        "  return capset\n",
        "\n",
        "\n",
        "# @funsearch.evolve\n",
        "def priority(el: tuple[int, ...], n: int) -\u003e float:\n",
        "  \"\"\"Returns the priority with which we want to add `element` to the cap set.\"\"\"\n",
        "  return 0.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QY5jPdo-g1fT"
      },
      "source": [
        "By executing the skeleton with the trivial `priority` function in place we can check that the resulting cap sets are far from optimal (e.g. recall that largest known cap set for `n = 9` has size `1082`):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 134,
          "status": "ok",
          "timestamp": 1697038278379,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "1cLP6xvzfn1k",
        "outputId": "7b371fd6-ad19-4459-d68d-0ccfa9e2927a"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "3 8\n",
            "4 16\n",
            "5 32\n",
            "6 64\n",
            "7 128\n",
            "8 256\n",
            "9 512\n"
          ]
        }
      ],
      "source": [
        "for n in range(3, 9+1):\n",
        "  print(n, evaluate(n))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I9-mf0aThXQl"
      },
      "source": [
        "## Discovered function that builds a $512$-cap in $n = 8$ dimensions\n",
        "\n",
        "This function discovered by FunSearch results in a cap set of size $512$ in $n = 8$ dimensions, thus improving upon the previously known best construction (which had size $496$)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k-k8WyrohG8I"
      },
      "outputs": [],
      "source": [
        "def priority(el: tuple[int, ...], n: int) -\u003e float:\n",
        "  score = n\n",
        "  in_el = 0\n",
        "  el_count = el.count(0)\n",
        "\n",
        "  if el_count == 0:\n",
        "    score += n ** 2\n",
        "    if el[1] == el[-1]:\n",
        "      score *= 1.5\n",
        "    if el[2] == el[-2]:\n",
        "      score *= 1.5\n",
        "    if el[3] == el[-3]:\n",
        "      score *= 1.5\n",
        "  else:\n",
        "    if el[1] == el[-1]:\n",
        "      score *= 0.5\n",
        "    if el[2] == el[-2]:\n",
        "      score *= 0.5\n",
        "\n",
        "  for e in el:\n",
        "    if e == 0:\n",
        "      if in_el == 0:\n",
        "        score *= n * 0.5\n",
        "      elif in_el == el_count - 1:\n",
        "        score *= 0.5\n",
        "      else:\n",
        "        score *= n * 0.5 ** in_el\n",
        "      in_el += 1\n",
        "    else:\n",
        "      score += 1\n",
        "\n",
        "  if el[1] == el[-1]:\n",
        "    score *= 1.5\n",
        "  if el[2] == el[-2]:\n",
        "    score *= 1.5\n",
        "\n",
        "  return score\n",
        "\n",
        "\n",
        "# We call the `solve` function instead of `evaluate` so that we get access to\n",
        "# the cap set itself (rather than just its size), for verification and\n",
        "# inspection purposes.\n",
        "cap_set_n8 = solve(8)\n",
        "assert cap_set_n8.shape == (512, 8)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uTFYifHWiEO3"
      },
      "source": [
        "We make use of a helper function to verify that the cap set is indeed valid."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZvQqqXS_iDhY"
      },
      "outputs": [],
      "source": [
        "def is_cap_set(vectors: np.ndarray) -\u003e bool:\n",
        "  \"\"\"Returns whether `vectors` form a valid cap set.\n",
        "\n",
        "  Checking the cap set property naively takes O(c^3 n) time, where c is the size\n",
        "  of the cap set. This function implements a faster check that runs in O(c^2 n).\n",
        "\n",
        "  Args:\n",
        "    vectors: [c, n] array containing c n-dimensional vectors over {0, 1, 2}.\n",
        "  \"\"\"\n",
        "  _, n = vectors.shape\n",
        "\n",
        "  # Convert `vectors` elements into raveled indices (numbers in [0, 3^n) ).\n",
        "  powers = np.array([3 ** j for j in range(n - 1, -1, -1)], dtype=int)  # [n]\n",
        "  raveled = np.einsum('in,n-\u003ei', vectors, powers)  # [c]\n",
        "\n",
        "  # Starting from the empty set, we iterate through `vectors` one by one and at\n",
        "  # each step check that the vector can be inserted into the set without\n",
        "  # violating the defining property of cap set. To make this check fast we\n",
        "  # maintain a vector `is_blocked` indicating for each element of Z_3^n whether\n",
        "  # that element can be inserted into the growing set without violating the cap\n",
        "  # set property.\n",
        "  is_blocked = np.full(shape=3 ** n, fill_value=False, dtype=bool)\n",
        "  for i, (new_vector, new_index) in enumerate(zip(vectors, raveled)):\n",
        "    if is_blocked[new_index]:\n",
        "      return False  # Inserting the i-th element violated the cap set property.\n",
        "    if i \u003e= 1:\n",
        "      # Update which elements are blocked after the insertion of `new_vector`.\n",
        "      blocking = np.einsum(\n",
        "          'nk,k-\u003en',\n",
        "          (- vectors[:i, :] - new_vector[None, :]) % 3, powers)\n",
        "      is_blocked[blocking] = True\n",
        "    is_blocked[new_index] = True  # In case `vectors` contains duplicates.\n",
        "  return True  # All elements inserted without violating the cap set property.\n",
        "\n",
        "\n",
        "assert is_cap_set(cap_set_n8)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W74wDTgB0KOn"
      },
      "source": [
        "We can start noticing some regularities in the discovered cap set if we inspect the number of nonzero entries (weights) of each of the 512 vectors:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 53,
          "status": "ok",
          "timestamp": 1697038278944,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "_tRWqFAVzb6R",
        "outputId": "f36a956e-3b53-42da-a7e7-0965fad27770"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8\n",
            " 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8\n",
            " 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8\n",
            " 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4\n",
            " 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4\n",
            " 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4\n",
            " 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4\n",
            " 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4\n",
            " 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4\n",
            " 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4\n",
            " 4 4 4 4 4 4 4 4 4 4 4 4 4 4 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5\n",
            " 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5\n",
            " 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5\n",
            " 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5]\n"
          ]
        }
      ],
      "source": [
        "print(np.count_nonzero(cap_set_n8, axis=1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iOSNsJDQ0a_e"
      },
      "source": [
        "## Explicit construction of a $512$-cap in $n = 8$ dimensions\n",
        "\n",
        "Thanks to discovering this cap set by searching in function space and noticing some regularities like the one above, we were able to manually find the following explicit construction of this new $512$-cap. See the paper's Supplementary Information for more details.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9D3B5-nz1k9x"
      },
      "outputs": [],
      "source": [
        "def build_512_cap() -\u003e list[tuple[int, ...]]:\n",
        "  \"\"\"Returns a cap set of size 512 in `n=8` dimensions.\"\"\"\n",
        "  n = 8\n",
        "  V = list(itertools.product(range(3), repeat=n))\n",
        "  support = lambda v: tuple(i for i in range(n) if v[i] != 0)\n",
        "  reflections = lambda v: sum(1 for i in range(1, n // 2) if v[i] == v[-i])\n",
        "\n",
        "  # Add all 128 weight-8 vectors that have \u003e= 2 reflections.\n",
        "  weight8_vectors = [v for v in V\n",
        "                     if len(support(v)) == 8  # Weight is 8.\n",
        "                     and reflections(v) \u003e= 2]  # At least 2 reflections.\n",
        "\n",
        "  # Add all 128 weight-4 vectors that have specific support.\n",
        "  supports_16 = [(0, 1, 2, 3), (0, 1, 2, 5), (0, 3, 6, 7), (0, 5, 6, 7),\n",
        "                 (1, 3, 4, 6), (1, 4, 5, 6), (2, 3, 4, 7), (2, 4, 5, 7)]\n",
        "  weight4_vectors = [v for v in V\n",
        "                     if support(v) in supports_16]\n",
        "\n",
        "  # Add all 128 weight-4 vectors with specific support and 1 reflection.\n",
        "  supports_8 = [(0, 1, 2, 7), (0, 1, 2, 6), (0, 1, 3, 7), (0, 1, 6, 7),\n",
        "                (0, 1, 5, 7), (0, 2, 3, 6), (0, 2, 6, 7), (0, 2, 5, 6),\n",
        "                (1, 2, 4, 7), (1, 2, 4, 6), (1, 3, 4, 7), (1, 4, 6, 7),\n",
        "                (1, 4, 5, 7), (2, 3, 4, 6), (2, 4, 6, 7), (2, 4, 5, 6)]\n",
        "  weight4_vectors_2 = [v for v in V\n",
        "                       if support(v) in supports_8\n",
        "                       and reflections(v) == 1]  # Exactly 1 reflection.\n",
        "\n",
        "  # Add 128 weight-5 vectors with \u003c= 1 reflections and one more condition.\n",
        "  allowed_zeros = [(0, 4, 7), (0, 2, 4), (0, 1, 4), (0, 4, 6),\n",
        "                   (1, 2, 6), (2, 6, 7), (1, 2, 7), (1, 6, 7)]\n",
        "  weight5_vectors = [\n",
        "      v for v in V\n",
        "      if tuple(i for i in range(n) if v[i] == 0) in allowed_zeros\n",
        "      and reflections(v) \u003c= 1  # At most 1 reflection.\n",
        "      and (v[1] * v[7]) % 3 != 1 and (v[2] * v[6]) % 3 != 1]\n",
        "\n",
        "  return weight8_vectors + weight4_vectors + weight4_vectors_2 + weight5_vectors\n",
        "\n",
        "\n",
        "explicit = np.array(build_512_cap(), dtype=np.int32)\n",
        "assert explicit.shape == (512, 8)\n",
        "assert is_cap_set(explicit)\n",
        "# The explicit construction builds the same cap set as a set (i.e. up to\n",
        "# permutation of rows).\n",
        "assert set(map(tuple, explicit)) == set(map(tuple, cap_set_n8))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jKTDz_ChifB8"
      },
      "source": [
        "## Discovered function that builds a $1082$-cap in $n = 9$ dimensions\n",
        "\n",
        "This matches the previously known best construction, which involves a mathematical argument utilising a special kind of product construction. Comments in the code were added by us."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uOz3hX0YiWKd"
      },
      "outputs": [],
      "source": [
        "def priority(el: tuple[int, ...], n: int) -\u003e float:\n",
        "  el = np.array(el, dtype=np.float32)\n",
        "  weight = (el @ el) % 3  # Weight (mod 3) of the full vector.\n",
        "  a = n // 3\n",
        "  b = n - n // 3\n",
        "  s_1 = (el[:b] @ el[:b]) % 3  # Weight (mod 3) of first two thirds.\n",
        "  s_3 = (2 * (el[:a] @ el[:a])) % 3  # Double norm of first third.\n",
        "  s_4 = (el[:a] @ el[a:b]) % 3  # Cross correlation.\n",
        "  s_5 = np.sum(el[:a] == el[-1]) % 3\n",
        "  return - 3 ** 3 * s_1 + 3 ** 2 * weight + 3 ** 3 * s_3 + 3 ** 2 * s_4 + s_5\n",
        "\n",
        "\n",
        "n = 9\n",
        "cap_set_n9 = solve(n)\n",
        "assert cap_set_n9.shape == (1082, 9)\n",
        "assert is_cap_set(cap_set_n9)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
